{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import matplotlib.pyplot as plt\n",
    "from ResNet_hc import resnet101\n",
    "import time\n",
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Image preprocessing pipelines built from torchvision transforms.\n",
    "# Both pipelines finish with the same tensor conversion and per-channel normalization.\n",
    "tensor_and_normalize = [\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n",
    "]\n",
    "\n",
    "# Training: resize, random crop to 224, random affine jitter, then the shared tail.\n",
    "train_augmentations = [\n",
    "    transforms.Resize(256),\n",
    "    transforms.RandomResizedCrop(224),\n",
    "    transforms.RandomAffine(degrees=15,scale=(0.8,1.5)),\n",
    "]\n",
    "train_transform = transforms.Compose(train_augmentations + tensor_and_normalize)\n",
    "\n",
    "# Validation: deterministic resize + center crop, then the shared tail.\n",
    "val_resizing = [\n",
    "    transforms.Resize(256),\n",
    "    transforms.CenterCrop(224),\n",
    "]\n",
    "val_transform = transforms.Compose(val_resizing + tensor_and_normalize)\n",
    "\n",
    "# Options common to both loaders; only the shuffle flag differs.\n",
    "loader_options = dict(batch_size=128, num_workers=4)\n",
    "\n",
    "trainset = torchvision.datasets.ImageFolder(root='../data/train/', transform=train_transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_options)\n",
    "\n",
    "valset = torchvision.datasets.ImageFolder(root='../data/val/', transform=val_transform)\n",
    "valloader = torch.utils.data.DataLoader(valset, shuffle=False, **loader_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "165 32\n"
     ]
    }
   ],
   "source": [
    "# Number of batches each loader yields per pass (train, then val).\n",
    "n_train_batches, n_val_batches = len(trainloader), len(valloader)\n",
    "print(n_train_batches, n_val_batches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the network and load the pretrained weights.\n",
    "# NOTE(review): the argument 2 is presumably the number of output classes -- confirm against ResNet_hc.\n",
    "model = resnet101(2)\n",
    "# map_location='cpu' keeps the checkpoint loadable on machines without CUDA,\n",
    "# regardless of the device it was saved from; the model is moved to the\n",
    "# training device later in the notebook.\n",
    "pretrained_state = torch.load('pretrained/resnet101-5d3b4d8f.pth', map_location='cpu')\n",
    "# strict=False: keys that do not match between checkpoint and model are skipped instead of raising.\n",
    "model.load_state_dict(pretrained_state, strict=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "We are using 2 GPUs!\n",
      "epoch 0,iter 20,train accuracy: 42.1484%   loss:  0.7372\n",
      "epoch 0,iter 40,train accuracy: 69.4141%   loss:  0.6270\n",
      "epoch 0,iter 60,train accuracy: 81.4844%   loss:  0.5404\n",
      "epoch 0,iter 80,train accuracy: 87.0703%   loss:  0.4702\n",
      "epoch 0,iter 100,train accuracy: 88.3984%   loss:  0.4224\n",
      "epoch 0,iter 120,train accuracy: 88.7109%   loss:  0.3820\n",
      "epoch 0,iter 140,train accuracy: 89.2969%   loss:  0.3575\n",
      "epoch 0,iter 160,train accuracy: 89.7656%   loss:  0.3369\n",
      "waitting for Val...\n",
      "epoch 0  The ValSet accuracy is 97.3750% \n",
      "\n",
      "Find Better Model and Saving it...\n",
      "Saved!\n",
      "Training complete in 2m 30s\n",
      "Now the best val Acc is 97.3750%\n",
      "epoch 1,iter 20,train accuracy: 91.0547%   loss:  0.2939\n",
      "epoch 1,iter 40,train accuracy: 91.3672%   loss:  0.2861\n",
      "epoch 1,iter 60,train accuracy: 90.7031%   loss:  0.2738\n",
      "epoch 1,iter 80,train accuracy: 90.7812%   loss:  0.2635\n",
      "epoch 1,iter 100,train accuracy: 91.2891%   loss:  0.2533\n",
      "epoch 1,iter 120,train accuracy: 93.0469%   loss:  0.2327\n",
      "epoch 1,iter 140,train accuracy: 92.5391%   loss:  0.2311\n",
      "epoch 1,iter 160,train accuracy: 92.1484%   loss:  0.2218\n",
      "waitting for Val...\n",
      "epoch 1  The ValSet accuracy is 98.1500% \n",
      "\n",
      "Find Better Model and Saving it...\n",
      "Saved!\n",
      "Training complete in 5m 12s\n",
      "Now the best val Acc is 98.1500%\n",
      "epoch 2,iter 20,train accuracy: 91.3281%   loss:  0.2242\n",
      "epoch 2,iter 40,train accuracy: 92.0703%   loss:  0.2114\n",
      "epoch 2,iter 60,train accuracy: 92.2266%   loss:  0.2087\n",
      "epoch 2,iter 80,train accuracy: 92.3828%   loss:  0.2017\n",
      "epoch 2,iter 100,train accuracy: 92.9688%   loss:  0.2006\n",
      "epoch 2,iter 120,train accuracy: 93.5547%   loss:  0.1855\n",
      "epoch 2,iter 140,train accuracy: 92.1484%   loss:  0.1944\n",
      "epoch 2,iter 160,train accuracy: 92.5391%   loss:  0.1930\n",
      "waitting for Val...\n",
      "epoch 2  The ValSet accuracy is 98.3500% \n",
      "\n",
      "Find Better Model and Saving it...\n",
      "Saved!\n",
      "Training complete in 7m 50s\n",
      "Now the best val Acc is 98.3500%\n",
      "epoch 3,iter 20,train accuracy: 91.9141%   loss:  0.1910\n",
      "epoch 3,iter 40,train accuracy: 92.3438%   loss:  0.1946\n",
      "epoch 3,iter 60,train accuracy: 92.9688%   loss:  0.1774\n",
      "epoch 3,iter 80,train accuracy: 92.4609%   loss:  0.1849\n",
      "epoch 3,iter 100,train accuracy: 93.2422%   loss:  0.1740\n",
      "epoch 3,iter 120,train accuracy: 93.4375%   loss:  0.1673\n",
      "epoch 3,iter 140,train accuracy: 93.0469%   loss:  0.1690\n",
      "epoch 3,iter 160,train accuracy: 93.0469%   loss:  0.1776\n",
      "waitting for Val...\n",
      "epoch 3  The ValSet accuracy is 98.6000% \n",
      "\n",
      "Find Better Model and Saving it...\n",
      "Saved!\n",
      "Training complete in 10m 35s\n",
      "Now the best val Acc is 98.6000%\n",
      "epoch 4,iter 20,train accuracy: 93.2812%   loss:  0.1672\n",
      "epoch 4,iter 40,train accuracy: 93.9453%   loss:  0.1512\n",
      "epoch 4,iter 60,train accuracy: 92.5391%   loss:  0.1724\n",
      "epoch 4,iter 80,train accuracy: 93.3203%   loss:  0.1587\n",
      "epoch 4,iter 100,train accuracy: 93.8281%   loss:  0.1557\n",
      "epoch 4,iter 120,train accuracy: 93.8281%   loss:  0.1543\n",
      "epoch 4,iter 140,train accuracy: 93.4766%   loss:  0.1528\n",
      "epoch 4,iter 160,train accuracy: 93.6328%   loss:  0.1552\n",
      "waitting for Val...\n",
      "epoch 4  The ValSet accuracy is 98.7000% \n",
      "\n",
      "Find Better Model and Saving it...\n",
      "Saved!\n",
      "Training complete in 13m 13s\n",
      "Now the best val Acc is 98.7000%\n",
      "epoch 5,iter 20,train accuracy: 94.5703%   loss:  0.1435\n",
      "epoch 5,iter 40,train accuracy: 94.7266%   loss:  0.1427\n",
      "epoch 5,iter 60,train accuracy: 94.0234%   loss:  0.1458\n",
      "epoch 5,iter 80,train accuracy: 93.6328%   loss:  0.1591\n",
      "epoch 5,iter 100,train accuracy: 93.7891%   loss:  0.1464\n",
      "epoch 5,iter 120,train accuracy: 93.8672%   loss:  0.1428\n",
      "epoch 5,iter 140,train accuracy: 94.5312%   loss:  0.1411\n",
      "epoch 5,iter 160,train accuracy: 94.1406%   loss:  0.1386\n",
      "waitting for Val...\n",
      "epoch 5  The ValSet accuracy is 98.8250% \n",
      "\n",
      "Find Better Model and Saving it...\n",
      "Saved!\n",
      "Training complete in 15m 46s\n",
      "Now the best val Acc is 98.8250%\n",
      "epoch 6,iter 20,train accuracy: 93.7891%   loss:  0.1397\n",
      "epoch 6,iter 40,train accuracy: 94.2188%   loss:  0.1372\n",
      "epoch 6,iter 60,train accuracy: 93.9844%   loss:  0.1464\n",
      "epoch 6,iter 80,train accuracy: 94.1797%   loss:  0.1401\n",
      "epoch 6,iter 100,train accuracy: 94.1016%   loss:  0.1474\n",
      "epoch 6,iter 120,train accuracy: 94.9609%   loss:  0.1314\n",
      "epoch 6,iter 140,train accuracy: 94.1797%   loss:  0.1360\n",
      "epoch 6,iter 160,train accuracy: 93.8281%   loss:  0.1442\n",
      "waitting for Val...\n",
      "epoch 6  The ValSet accuracy is 98.8000% \n",
      "\n",
      "Training complete in 18m 26s\n",
      "Now the best val Acc is 98.8250%\n",
      "epoch 7,iter 20,train accuracy: 93.3203%   loss:  0.1465\n",
      "epoch 7,iter 40,train accuracy: 93.7109%   loss:  0.1408\n",
      "epoch 7,iter 60,train accuracy: 94.4531%   loss:  0.1334\n",
      "epoch 7,iter 80,train accuracy: 93.7500%   loss:  0.1421\n",
      "epoch 7,iter 100,train accuracy: 94.2969%   loss:  0.1390\n",
      "epoch 7,iter 120,train accuracy: 94.1406%   loss:  0.1366\n",
      "epoch 7,iter 140,train accuracy: 94.4922%   loss:  0.1237\n",
      "epoch 7,iter 160,train accuracy: 93.7500%   loss:  0.1400\n",
      "waitting for Val...\n",
      "epoch 7  The ValSet accuracy is 98.8500% \n",
      "\n",
      "Find Better Model and Saving it...\n",
      "Saved!\n",
      "Training complete in 21m 8s\n",
      "Now the best val Acc is 98.8500%\n",
      "epoch 8,iter 20,train accuracy: 94.8828%   loss:  0.1285\n",
      "epoch 8,iter 40,train accuracy: 94.4922%   loss:  0.1374\n",
      "epoch 8,iter 60,train accuracy: 94.1406%   loss:  0.1375\n",
      "epoch 8,iter 80,train accuracy: 94.6484%   loss:  0.1290\n",
      "epoch 8,iter 100,train accuracy: 94.6484%   loss:  0.1258\n",
      "epoch 8,iter 120,train accuracy: 94.3750%   loss:  0.1347\n",
      "epoch 8,iter 140,train accuracy: 94.8828%   loss:  0.1194\n",
      "epoch 8,iter 160,train accuracy: 94.5312%   loss:  0.1298\n",
      "waitting for Val...\n",
      "epoch 8  The ValSet accuracy is 98.8000% \n",
      "\n",
      "Training complete in 23m 45s\n",
      "Now the best val Acc is 98.8500%\n",
      "epoch 9,iter 20,train accuracy: 94.8828%   loss:  0.1322\n",
      "epoch 9,iter 40,train accuracy: 95.1953%   loss:  0.1175\n",
      "epoch 9,iter 60,train accuracy: 94.8047%   loss:  0.1242\n",
      "epoch 9,iter 80,train accuracy: 93.8672%   loss:  0.1458\n",
      "epoch 9,iter 100,train accuracy: 94.9609%   loss:  0.1168\n",
      "epoch 9,iter 120,train accuracy: 94.9219%   loss:  0.1187\n",
      "epoch 9,iter 140,train accuracy: 94.7656%   loss:  0.1184\n",
      "epoch 9,iter 160,train accuracy: 95.2344%   loss:  0.1202\n",
      "waitting for Val...\n",
      "epoch 9  The ValSet accuracy is 98.8750% \n",
      "\n",
      "Find Better Model and Saving it...\n",
      "Saved!\n",
      "Training complete in 26m 23s\n",
      "Now the best val Acc is 98.8750%\n",
      "epoch 10,iter 20,train accuracy: 94.1797%   loss:  0.1366\n",
      "epoch 10,iter 40,train accuracy: 94.5312%   loss:  0.1281\n",
      "epoch 10,iter 60,train accuracy: 94.5312%   loss:  0.1355\n",
      "epoch 10,iter 80,train accuracy: 94.6875%   loss:  0.1304\n",
      "epoch 10,iter 100,train accuracy: 94.8828%   loss:  0.1150\n",
      "epoch 10,iter 120,train accuracy: 94.9219%   loss:  0.1183\n",
      "epoch 10,iter 140,train accuracy: 94.3750%   loss:  0.1247\n",
      "epoch 10,iter 160,train accuracy: 94.8438%   loss:  0.1212\n",
      "waitting for Val...\n",
      "epoch 10  The ValSet accuracy is 98.9500% \n",
      "\n",
      "Find Better Model and Saving it...\n",
      "Saved!\n",
      "Training complete in 29m 1s\n",
      "Now the best val Acc is 98.9500%\n",
      "epoch 11,iter 20,train accuracy: 94.6484%   loss:  0.1208\n",
      "epoch 11,iter 40,train accuracy: 95.1953%   loss:  0.1146\n",
      "epoch 11,iter 60,train accuracy: 95.4688%   loss:  0.1193\n",
      "epoch 11,iter 80,train accuracy: 95.6250%   loss:  0.1181\n",
      "epoch 11,iter 100,train accuracy: 94.8828%   loss:  0.1225\n",
      "epoch 11,iter 120,train accuracy: 95.2344%   loss:  0.1105\n",
      "epoch 11,iter 140,train accuracy: 94.8438%   loss:  0.1249\n",
      "epoch 11,iter 160,train accuracy: 95.4297%   loss:  0.1137\n",
      "waitting for Val...\n",
      "epoch 11  The ValSet accuracy is 99.0000% \n",
      "\n",
      "Find Better Model and Saving it...\n",
      "Saved!\n",
      "Training complete in 31m 35s\n",
      "Now the best val Acc is 99.0000%\n",
      "epoch 12,iter 20,train accuracy: 94.6094%   loss:  0.1195\n",
      "epoch 12,iter 40,train accuracy: 94.8438%   loss:  0.1208\n",
      "epoch 12,iter 60,train accuracy: 95.1172%   loss:  0.1106\n",
      "epoch 12,iter 80,train accuracy: 95.1562%   loss:  0.1104\n",
      "epoch 12,iter 100,train accuracy: 95.5469%   loss:  0.1123\n",
      "epoch 12,iter 120,train accuracy: 95.2344%   loss:  0.1150\n",
      "epoch 12,iter 140,train accuracy: 94.6094%   loss:  0.1314\n",
      "epoch 12,iter 160,train accuracy: 94.6484%   loss:  0.1171\n",
      "waitting for Val...\n",
      "epoch 12  The ValSet accuracy is 98.9250% \n",
      "\n",
      "Training complete in 34m 13s\n",
      "Now the best val Acc is 99.0000%\n",
      "epoch 13,iter 20,train accuracy: 94.8828%   loss:  0.1177\n",
      "epoch 13,iter 40,train accuracy: 95.1953%   loss:  0.1188\n",
      "epoch 13,iter 60,train accuracy: 95.0781%   loss:  0.1166\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 13,iter 80,train accuracy: 95.3516%   loss:  0.1122\n",
      "epoch 13,iter 100,train accuracy: 95.3125%   loss:  0.1139\n",
      "epoch 13,iter 120,train accuracy: 94.8438%   loss:  0.1235\n",
      "epoch 13,iter 140,train accuracy: 95.0000%   loss:  0.1124\n",
      "epoch 13,iter 160,train accuracy: 95.1172%   loss:  0.1112\n",
      "waitting for Val...\n",
      "epoch 13  The ValSet accuracy is 98.9750% \n",
      "\n",
      "Training complete in 36m 47s\n",
      "Now the best val Acc is 99.0000%\n",
      "epoch 14,iter 20,train accuracy: 94.9219%   loss:  0.1152\n",
      "epoch 14,iter 40,train accuracy: 94.8828%   loss:  0.1179\n",
      "epoch 14,iter 60,train accuracy: 95.4297%   loss:  0.1136\n",
      "epoch 14,iter 80,train accuracy: 95.5859%   loss:  0.1086\n",
      "epoch 14,iter 100,train accuracy: 94.9609%   loss:  0.1201\n",
      "epoch 14,iter 120,train accuracy: 94.6094%   loss:  0.1230\n",
      "epoch 14,iter 140,train accuracy: 94.6094%   loss:  0.1161\n",
      "epoch 14,iter 160,train accuracy: 95.6250%   loss:  0.0998\n",
      "waitting for Val...\n",
      "epoch 14  The ValSet accuracy is 98.9000% \n",
      "\n",
      "Training complete in 39m 18s\n",
      "Now the best val Acc is 99.0000%\n",
      "epoch 15,iter 20,train accuracy: 95.5078%   loss:  0.1096\n",
      "epoch 15,iter 40,train accuracy: 95.5469%   loss:  0.1016\n",
      "epoch 15,iter 60,train accuracy: 94.8047%   loss:  0.1104\n",
      "epoch 15,iter 80,train accuracy: 95.2344%   loss:  0.1107\n",
      "epoch 15,iter 100,train accuracy: 95.4297%   loss:  0.1096\n",
      "epoch 15,iter 120,train accuracy: 95.6250%   loss:  0.1082\n",
      "epoch 15,iter 140,train accuracy: 94.8047%   loss:  0.1157\n",
      "epoch 15,iter 160,train accuracy: 95.7031%   loss:  0.1094\n",
      "waitting for Val...\n",
      "epoch 15  The ValSet accuracy is 99.0250% \n",
      "\n",
      "Find Better Model and Saving it...\n",
      "Saved!\n",
      "Training complete in 41m 49s\n",
      "Now the best val Acc is 99.0250%\n",
      "epoch 16,iter 20,train accuracy: 94.9219%   loss:  0.1102\n",
      "epoch 16,iter 40,train accuracy: 95.4688%   loss:  0.1080\n",
      "epoch 16,iter 60,train accuracy: 95.0391%   loss:  0.1108\n",
      "epoch 16,iter 80,train accuracy: 94.9609%   loss:  0.1125\n",
      "epoch 16,iter 100,train accuracy: 95.2344%   loss:  0.1129\n",
      "epoch 16,iter 120,train accuracy: 95.5859%   loss:  0.1034\n",
      "epoch 16,iter 140,train accuracy: 95.1953%   loss:  0.1148\n",
      "epoch 16,iter 160,train accuracy: 95.8203%   loss:  0.1025\n",
      "waitting for Val...\n",
      "epoch 16  The ValSet accuracy is 99.0750% \n",
      "\n",
      "Find Better Model and Saving it...\n",
      "Saved!\n",
      "Training complete in 44m 30s\n",
      "Now the best val Acc is 99.0750%\n",
      "epoch 17,iter 20,train accuracy: 94.8828%   loss:  0.1159\n",
      "epoch 17,iter 40,train accuracy: 95.3125%   loss:  0.1064\n",
      "epoch 17,iter 60,train accuracy: 95.3906%   loss:  0.1061\n",
      "epoch 17,iter 80,train accuracy: 95.3906%   loss:  0.1134\n",
      "epoch 17,iter 100,train accuracy: 95.8594%   loss:  0.1047\n",
      "epoch 17,iter 120,train accuracy: 95.3125%   loss:  0.1141\n",
      "epoch 17,iter 140,train accuracy: 96.0938%   loss:  0.0981\n",
      "epoch 17,iter 160,train accuracy: 94.7656%   loss:  0.1181\n",
      "waitting for Val...\n",
      "epoch 17  The ValSet accuracy is 99.1500% \n",
      "\n",
      "Find Better Model and Saving it...\n",
      "Saved!\n",
      "Training complete in 47m 7s\n",
      "Now the best val Acc is 99.1500%\n",
      "epoch 18,iter 20,train accuracy: 95.4297%   loss:  0.1036\n",
      "epoch 18,iter 40,train accuracy: 95.3125%   loss:  0.1103\n",
      "epoch 18,iter 60,train accuracy: 96.0938%   loss:  0.0911\n",
      "epoch 18,iter 80,train accuracy: 94.5312%   loss:  0.1203\n",
      "epoch 18,iter 100,train accuracy: 96.2500%   loss:  0.0958\n",
      "epoch 18,iter 120,train accuracy: 95.3906%   loss:  0.1100\n",
      "epoch 18,iter 140,train accuracy: 95.8203%   loss:  0.1018\n",
      "epoch 18,iter 160,train accuracy: 95.6641%   loss:  0.1120\n",
      "waitting for Val...\n",
      "epoch 18  The ValSet accuracy is 99.1250% \n",
      "\n",
      "Training complete in 49m 38s\n",
      "Now the best val Acc is 99.1500%\n",
      "epoch 19,iter 20,train accuracy: 95.8203%   loss:  0.0927\n",
      "epoch 19,iter 40,train accuracy: 95.2734%   loss:  0.1075\n",
      "epoch 19,iter 60,train accuracy: 95.8203%   loss:  0.1000\n",
      "epoch 19,iter 80,train accuracy: 95.4688%   loss:  0.1005\n",
      "epoch 19,iter 100,train accuracy: 96.0547%   loss:  0.0983\n",
      "epoch 19,iter 120,train accuracy: 95.3125%   loss:  0.1095\n",
      "epoch 19,iter 140,train accuracy: 96.3672%   loss:  0.0922\n",
      "epoch 19,iter 160,train accuracy: 95.8203%   loss:  0.0997\n",
      "waitting for Val...\n",
      "epoch 19  The ValSet accuracy is 99.1750% \n",
      "\n",
      "Find Better Model and Saving it...\n",
      "Saved!\n",
      "Training complete in 52m 13s\n",
      "Now the best val Acc is 99.1750%\n"
     ]
    }
   ],
   "source": [
    "# Run on the first visible GPU when CUDA is available, otherwise on the CPU.\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "# Wrap the model for multi-GPU use. The isinstance guard keeps a re-run of this\n",
    "# cell from nesting DataParallel inside DataParallel.\n",
    "if torch.cuda.device_count() > 1 and not isinstance(model, nn.DataParallel):\n",
    "    print('We are using',torch.cuda.device_count(),'GPUs!')\n",
    "    model = nn.DataParallel(model)\n",
    "model.to(device)\n",
    "\n",
    "# Loss function and optimizer.\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)\n",
    "\n",
    "# Training-run settings.\n",
    "NUM_EPOCHS = 20\n",
    "LOG_EVERY = 20  # print one progress line every LOG_EVERY training batches\n",
    "CHECKPOINT_PATH = './checkpoint/ResNet101_Cats_Dogs_hc.pth'\n",
    "\n",
    "\n",
    "def _save_curve(values, title, ylabel, path):\n",
    "    '''Plot `values` against their index (one point per epoch) and save the figure to `path`.'''\n",
    "    fig, ax = plt.subplots(figsize=(11, 8))\n",
    "    ax.plot(range(len(values)), values)\n",
    "    ax.set_title(title)\n",
    "    ax.set_xlabel('Epoch')\n",
    "    ax.set_ylabel(ylabel)\n",
    "    fig.savefig(path)\n",
    "    # Close explicitly: three figures are created per epoch and would otherwise stay open.\n",
    "    plt.close(fig)\n",
    "\n",
    "\n",
    "# Per-epoch history: train accuracy (%), mean train loss per sample, validation accuracy (%).\n",
    "Accuracy = []\n",
    "Loss = []\n",
    "Val_Accuracy = []\n",
    "BEST_VAL_ACC = 0.\n",
    "# Training\n",
    "since = time.time()\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    # Epoch totals include every batch; the run_* counters only cover the current\n",
    "    # logging window and are reset after each progress line.\n",
    "    epoch_loss_sum, epoch_correct, epoch_samples = 0., 0, 0\n",
    "    run_loss_sum, run_correct, run_samples = 0., 0, 0\n",
    "    model.train()\n",
    "    for i, (images, labels) in enumerate(trainloader):\n",
    "        images = images.to(device)\n",
    "        labels = labels.to(device)\n",
    "        # Standard step: zero gradients, forward, backward, parameter update.\n",
    "        optimizer.zero_grad()\n",
    "        outs = model(images)\n",
    "        loss = criterion(outs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        # CrossEntropyLoss returns the batch mean by default, so weight it by the batch\n",
    "        # size to accumulate a per-sample sum; a shorter final batch is then counted correctly.\n",
    "        batch_size = labels.size(0)\n",
    "        batch_loss_sum = loss.item() * batch_size\n",
    "        _, prediction = torch.max(outs, 1)\n",
    "        batch_correct = (prediction == labels).sum().item()\n",
    "        # Update the epoch totals on every batch, not only when a progress line is\n",
    "        # printed, so batches after the last full logging window are not dropped.\n",
    "        epoch_loss_sum += batch_loss_sum\n",
    "        epoch_correct += batch_correct\n",
    "        epoch_samples += batch_size\n",
    "        run_loss_sum += batch_loss_sum\n",
    "        run_correct += batch_correct\n",
    "        run_samples += batch_size\n",
    "        if i % LOG_EVERY == LOG_EVERY - 1:\n",
    "            print('epoch {},iter {},train accuracy: {:.4f}%   loss:  {:.4f}'.format(epoch, i+1, 100*run_correct/run_samples, run_loss_sum/run_samples))\n",
    "            run_loss_sum, run_correct, run_samples = 0., 0, 0\n",
    "    # max(..., 1) avoids a ZeroDivisionError if the loader yields no batches.\n",
    "    Loss.append(epoch_loss_sum/max(epoch_samples, 1))\n",
    "    Accuracy.append(100*epoch_correct/max(epoch_samples, 1))\n",
    "    # Save the training curves.\n",
    "    _save_curve(Accuracy, 'Average trainset accuracy vs epochs', 'Avg. train. accuracy', 'Train_accuracy_vs_epochs.png')\n",
    "    _save_curve(Loss, 'Average trainset loss vs epochs', 'Current loss', 'loss_vs_epochs.png')\n",
    "    # Validation\n",
    "    model.eval()\n",
    "    print('waitting for Val...')\n",
    "    val_correct, val_samples = 0, 0\n",
    "    with torch.no_grad():\n",
    "        for images, labels in valloader:\n",
    "            images = images.to(device)\n",
    "            labels = labels.to(device)\n",
    "            out = model(images)\n",
    "            _, prediction = torch.max(out, 1)\n",
    "            val_samples += labels.size(0)\n",
    "            val_correct += (prediction == labels).sum().item()\n",
    "    # Computed once after the loop; stays 0 when the validation loader is empty.\n",
    "    acc = 100.*val_correct/max(val_samples, 1)\n",
    "    print('epoch {}  The ValSet accuracy is {:.4f}% \\n'.format(epoch, acc))\n",
    "    Val_Accuracy.append(acc)\n",
    "    if acc > BEST_VAL_ACC:\n",
    "        print('Find Better Model and Saving it...')\n",
    "        os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)\n",
    "        # NOTE: when the model is wrapped in nn.DataParallel the saved keys carry a\n",
    "        # 'module.' prefix. The format is left unchanged here so anything that already\n",
    "        # loads this checkpoint is unaffected.\n",
    "        torch.save(model.state_dict(), CHECKPOINT_PATH)\n",
    "        BEST_VAL_ACC = acc\n",
    "        print('Saved!')\n",
    "    \n",
    "    _save_curve(Val_Accuracy, 'Average Val accuracy vs epochs', 'Current Val accuracy', 'val_accuracy_vs_epoch.png')\n",
    "    # Elapsed wall-clock time since training started (cumulative, not per epoch).\n",
    "    time_elapsed = time.time() - since\n",
    "    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed%60))\n",
    "    print('Now the best val Acc is {:.4f}%'.format(BEST_VAL_ACC))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:Funny] *",
   "language": "python",
   "name": "conda-env-Funny-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
